{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Imdb 文本建模.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g8SIGld3s7Bt"
      },
      "source": [
        "# 文本数据建模\n",
        "\n",
        "Imdb 数据集的目标是根据电影评论的文本内容预测评论的情感标签。训练集有20000条电影评论文本，测试集有5000条电影评论文本，其中正面评论和负面评论都各占一半。文本数据预处理较为繁琐，包括中文切词（本示例不涉及），构建词典，编码转换，序列填充，构建数据管道等等。\n",
        "\n",
        "在 TensorFlow 中完成文本数据预处理的常用方案有两种：\n",
        "\n",
        "- 第一种是利用 tf.keras.preprocessing 中的 Tokenizer 词典构建工具和 tf.keras.utils.Sequence 构建文本数据生成器管道。\n",
        "\n",
        "- 第二种是使用 tf.data.Dataset 搭配 tf.keras.layers.experimental.preprocessing.TextVectorization 预处理层。\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YyFa8YbVtdyY",
        "outputId": "616803bd-97e6-40a9-9272-c7d78f06845d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 411
        }
      },
      "source": [
        "# 下载数据\n",
        "!wget https://raw.githubusercontent.com/lyhue1991/eat_tensorflow2_in_30_days/master/data/imdb/train.csv -O ./imdb_train.csv\n",
        "\n",
        "!wget https://raw.githubusercontent.com/lyhue1991/eat_tensorflow2_in_30_days/master/data/imdb/test.csv -O ./imdb_test.csv"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2020-10-06 07:06:01--  https://raw.githubusercontent.com/lyhue1991/eat_tensorflow2_in_30_days/master/data/imdb/train.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 26683623 (25M) [text/plain]\n",
            "Saving to: ‘./imdb_train.csv’\n",
            "\n",
            "./imdb_train.csv    100%[===================>]  25.45M  35.3MB/s    in 0.7s    \n",
            "\n",
            "2020-10-06 07:06:02 (35.3 MB/s) - ‘./imdb_train.csv’ saved [26683623/26683623]\n",
            "\n",
            "--2020-10-06 07:06:02--  https://raw.githubusercontent.com/lyhue1991/eat_tensorflow2_in_30_days/master/data/imdb/test.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6638231 (6.3M) [text/plain]\n",
            "Saving to: ‘./imdb_test.csv’\n",
            "\n",
            "./imdb_test.csv     100%[===================>]   6.33M  15.7MB/s    in 0.4s    \n",
            "\n",
            "2020-10-06 07:06:03 (15.7 MB/s) - ‘./imdb_test.csv’ saved [6638231/6638231]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nivAmS5gtHy-",
        "outputId": "36e3de21-29aa-4cdd-d573-f7f1b3ded8dc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "source": [
        "import numpy as np \n",
        "import pandas as pd \n",
        "from matplotlib import pyplot as plt\n",
        "import tensorflow as tf\n",
        "from tensorflow.keras import models,layers,preprocessing,optimizers,losses,metrics\n",
        "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
        "import re,string\n",
        "\n",
        "train_data_path = \"./imdb_train.csv\"\n",
        "test_data_path =  \"./imdb_test.csv\"\n",
        "\n",
        "MAX_WORDS = 10000  # 仅考虑最高频的10000个词\n",
        "MAX_LEN = 200  # 每个样本保留200个词的长度\n",
        "BATCH_SIZE = 20 \n",
        "\n",
        "# 构建管道\n",
        "def split_line(line):\n",
        "    arr = tf.strings.split(line,\"\\t\")\n",
        "    label = tf.expand_dims(tf.cast(tf.strings.to_number(arr[0]),tf.int32),axis = 0)\n",
        "    text = tf.expand_dims(arr[1],axis = 0)\n",
        "    return (text,label)\n",
        "\n",
        "ds_train_raw = tf.data.TextLineDataset(filenames = [train_data_path]) \\\n",
        "   .map(split_line,num_parallel_calls = tf.data.experimental.AUTOTUNE) \\\n",
        "   .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \\\n",
        "   .prefetch(tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "ds_test_raw = tf.data.TextLineDataset(filenames = [test_data_path]) \\\n",
        "   .map(split_line,num_parallel_calls = tf.data.experimental.AUTOTUNE) \\\n",
        "   .batch(BATCH_SIZE) \\\n",
        "   .prefetch(tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "\n",
        "# 构建词典\n",
        "def clean_text(text):\n",
        "    lowercase = tf.strings.lower(text)\n",
        "    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')\n",
        "    cleaned_punctuation = tf.strings.regex_replace(stripped_html,\n",
        "         '[%s]' % re.escape(string.punctuation),'')\n",
        "    return cleaned_punctuation\n",
        "\n",
        "vectorize_layer = TextVectorization(\n",
        "    standardize=clean_text,\n",
        "    split = 'whitespace',\n",
        "    max_tokens=MAX_WORDS-1, #有一个留给占位符\n",
        "    output_mode='int',\n",
        "    output_sequence_length=MAX_LEN)\n",
        "\n",
        "ds_text = ds_train_raw.map(lambda text,label: text)\n",
        "vectorize_layer.adapt(ds_text)\n",
        "\n",
        "print(vectorize_layer.get_vocabulary()[0:100])\n",
        "\n",
        "\n",
        "#单词编码\n",
        "ds_train = ds_train_raw.map(lambda text,label:(vectorize_layer(text),label)) \\\n",
        "    .prefetch(tf.data.experimental.AUTOTUNE)\n",
        "ds_test = ds_test_raw.map(lambda text,label:(vectorize_layer(text),label)) \\\n",
        "    .prefetch(tf.data.experimental.AUTOTUNE)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['', '[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i', 'this', 'that', 'was', 'as', 'for', 'with', 'movie', 'but', 'film', 'on', 'not', 'you', 'his', 'are', 'have', 'be', 'he', 'one', 'its', 'at', 'all', 'by', 'an', 'they', 'from', 'who', 'so', 'like', 'her', 'just', 'or', 'about', 'has', 'if', 'out', 'some', 'there', 'what', 'good', 'more', 'when', 'very', 'she', 'even', 'my', 'no', 'would', 'up', 'time', 'only', 'which', 'story', 'really', 'their', 'were', 'had', 'see', 'can', 'me', 'than', 'we', 'much', 'well', 'get', 'been', 'will', 'into', 'people', 'also', 'other', 'do', 'bad', 'because', 'great', 'first', 'how', 'him', 'most', 'dont', 'made', 'then', 'them', 'films', 'movies', 'way', 'make', 'could', 'too', 'any']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7TOuR2FZvQyE"
      },
      "source": [
        "# 定义模型\n",
        "\n",
        "使用Keras接口有以下3种方式构建模型：使用Sequential按层顺序构建模型，使用函数式API构建任意结构模型，继承Model基类构建自定义模型。此处选择使用继承Model基类构建自定义模型。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "--9i4DvPvXdc",
        "outputId": "5d571bea-565b-4697-999d-389659b281c3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 425
        }
      },
      "source": [
        "# 演示自定义模型范例，实际上应该优先使用Sequential或者函数式API\n",
        "tf.keras.backend.clear_session()\n",
        "\n",
        "class CnnModel(models.Model):\n",
        "    def __init__(self):\n",
        "        super(CnnModel, self).__init__()\n",
        "        \n",
        "    def build(self,input_shape):\n",
        "        self.embedding = layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN)\n",
        "        self.conv_1 = layers.Conv1D(16, kernel_size= 5,name = \"conv_1\",activation = \"relu\")\n",
        "        self.pool_1 = layers.MaxPool1D(name = \"pool_1\")\n",
        "        self.conv_2 = layers.Conv1D(128, kernel_size=2,name = \"conv_2\",activation = \"relu\")\n",
        "        self.pool_2 = layers.MaxPool1D(name = \"pool_2\")\n",
        "        self.flatten = layers.Flatten()\n",
        "        self.dense = layers.Dense(1,activation = \"sigmoid\")\n",
        "        super(CnnModel,self).build(input_shape)\n",
        "    \n",
        "    def call(self, x):\n",
        "        x = self.embedding(x)\n",
        "        x = self.conv_1(x)\n",
        "        x = self.pool_1(x)\n",
        "        x = self.conv_2(x)\n",
        "        x = self.pool_2(x)\n",
        "        x = self.flatten(x)\n",
        "        x = self.dense(x)\n",
        "        return(x)\n",
        "    \n",
        "    # 用于显示Output Shape\n",
        "    def summary(self):\n",
        "        x_input = layers.Input(shape = MAX_LEN)\n",
        "        output = self.call(x_input)\n",
        "        model = tf.keras.Model(inputs = x_input,outputs = output)\n",
        "        model.summary()\n",
        "    \n",
        "model = CnnModel()\n",
        "model.build(input_shape =(None,MAX_LEN))\n",
        "model.summary()"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model: \"functional_1\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "input_1 (InputLayer)         [(None, 200)]             0         \n",
            "_________________________________________________________________\n",
            "embedding (Embedding)        (None, 200, 7)            70000     \n",
            "_________________________________________________________________\n",
            "conv_1 (Conv1D)              (None, 196, 16)           576       \n",
            "_________________________________________________________________\n",
            "pool_1 (MaxPooling1D)        (None, 98, 16)            0         \n",
            "_________________________________________________________________\n",
            "conv_2 (Conv1D)              (None, 97, 128)           4224      \n",
            "_________________________________________________________________\n",
            "pool_2 (MaxPooling1D)        (None, 48, 128)           0         \n",
            "_________________________________________________________________\n",
            "flatten (Flatten)            (None, 6144)              0         \n",
            "_________________________________________________________________\n",
            "dense (Dense)                (None, 1)                 6145      \n",
            "=================================================================\n",
            "Total params: 80,945\n",
            "Trainable params: 80,945\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tdVQyE1DvckM",
        "outputId": "4ed3d29d-9b06-45fc-cde4-b39adc27d7e9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 323
        }
      },
      "source": [
        "# 训练模型\n",
        "\n",
        "# 打印时间分割线\n",
        "@tf.function\n",
        "def printbar():\n",
        "    today_ts = tf.timestamp()%(24*60*60)\n",
        "    \n",
        "    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)\n",
        "    minite = tf.cast((today_ts%3600)//60,tf.int32)\n",
        "    second = tf.cast(tf.floor(today_ts%60),tf.int32)\n",
        "    \n",
        "    def timeformat(m):\n",
        "        if tf.strings.length(tf.strings.format(\"{}\",m))==1:\n",
        "            return(tf.strings.format(\"0{}\",m))\n",
        "        else:\n",
        "            return(tf.strings.format(\"{}\",m))\n",
        "    \n",
        "    timestring = tf.strings.join([timeformat(hour),timeformat(minite),\n",
        "                timeformat(second)],separator = \":\")\n",
        "    tf.print(\"==========\"*8+timestring)\n",
        "\n",
        "# 训练模型通常有3种方法，内置fit方法，内置train_on_batch方法，以及自定义训练循环。此处我们通过自定义训练循环训练模型。\n",
        "optimizer = optimizers.Nadam()\n",
        "loss_func = losses.BinaryCrossentropy()\n",
        "\n",
        "train_loss = metrics.Mean(name='train_loss')\n",
        "train_metric = metrics.BinaryAccuracy(name='train_accuracy')\n",
        "\n",
        "valid_loss = metrics.Mean(name='valid_loss')\n",
        "valid_metric = metrics.BinaryAccuracy(name='valid_accuracy')\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def train_step(model, features, labels):\n",
        "    with tf.GradientTape() as tape:\n",
        "        predictions = model(features,training = True)\n",
        "        loss = loss_func(labels, predictions)\n",
        "    gradients = tape.gradient(loss, model.trainable_variables)\n",
        "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "\n",
        "    train_loss.update_state(loss)\n",
        "    train_metric.update_state(labels, predictions)\n",
        "    \n",
        "\n",
        "@tf.function\n",
        "def valid_step(model, features, labels):\n",
        "    predictions = model(features,training = False)\n",
        "    batch_loss = loss_func(labels, predictions)\n",
        "    valid_loss.update_state(batch_loss)\n",
        "    valid_metric.update_state(labels, predictions)\n",
        "\n",
        "\n",
        "def train_model(model,ds_train,ds_valid,epochs):\n",
        "    for epoch in tf.range(1,epochs+1):\n",
        "        \n",
        "        for features, labels in ds_train:\n",
        "            train_step(model,features,labels)\n",
        "\n",
        "        for features, labels in ds_valid:\n",
        "            valid_step(model,features,labels)\n",
        "        \n",
        "        #此处logs模板需要根据metric具体情况修改\n",
        "        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}' \n",
        "        \n",
        "        if epoch%1==0:\n",
        "            printbar()\n",
        "            tf.print(tf.strings.format(logs,\n",
        "            (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))\n",
        "            tf.print(\"\")\n",
        "        \n",
        "        train_loss.reset_states()\n",
        "        valid_loss.reset_states()\n",
        "        train_metric.reset_states()\n",
        "        valid_metric.reset_states()\n",
        "\n",
        "train_model(model,ds_train,ds_test,epochs = 6)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "================================================================================15:09:29\n",
            "Epoch=1,Loss:0.448172688,Accuracy:0.7644,Valid Loss:0.340480655,Valid Accuracy:0.8512\n",
            "\n",
            "================================================================================15:09:42\n",
            "Epoch=2,Loss:0.251793057,Accuracy:0.9,Valid Loss:0.324303776,Valid Accuracy:0.8692\n",
            "\n",
            "================================================================================15:09:55\n",
            "Epoch=3,Loss:0.179901466,Accuracy:0.93115,Valid Loss:0.369245142,Valid Accuracy:0.865\n",
            "\n",
            "================================================================================15:10:08\n",
            "Epoch=4,Loss:0.121482134,Accuracy:0.9575,Valid Loss:0.455947042,Valid Accuracy:0.8608\n",
            "\n",
            "================================================================================15:10:20\n",
            "Epoch=5,Loss:0.0720875859,Accuracy:0.9756,Valid Loss:0.607589602,Valid Accuracy:0.8568\n",
            "\n",
            "================================================================================15:10:33\n",
            "Epoch=6,Loss:0.0391695425,Accuracy:0.98765,Valid Loss:0.784351528,Valid Accuracy:0.8502\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p7QYSie-voRR",
        "outputId": "01b6e8e5-2511-4712-f4f9-b0272eebfd57",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# 评估模型\n",
        "# 通过自定义训练循环训练的模型没有经过编译，无法直接使用model.evaluate(ds_valid)方法\n",
        "def evaluate_model(model,ds_valid):\n",
        "    for features, labels in ds_valid:\n",
        "         valid_step(model,features,labels)\n",
        "    logs = 'Valid Loss:{},Valid Accuracy:{}' \n",
        "    tf.print(tf.strings.format(logs,(valid_loss.result(),valid_metric.result())))\n",
        "    \n",
        "    valid_loss.reset_states()\n",
        "    train_metric.reset_states()\n",
        "    valid_metric.reset_states()\n",
        "\n",
        "    \n",
        "evaluate_model(model,ds_test)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Valid Loss:0.784351528,Valid Accuracy:0.8502\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iv7Uzwldvwvm"
      },
      "source": [
        "# 使用模型\n",
        "\n",
        "可以使用以下方法:\n",
        "\n",
        "- model.predict(ds_test)\n",
        "- model(x_test)\n",
        "- model.call(x_test)\n",
        "- model.predict_on_batch(x_test)\n",
        "\n",
        "推荐优先使用model.predict(ds_test)方法，既可以对Dataset，也可以对Tensor使用。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6PsYOXXZv1xm",
        "outputId": "476b9b66-41e2-467c-80c8-1814c6b028e3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 136
        }
      },
      "source": [
        "model.predict(ds_test)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[0.083307  ],\n",
              "       [0.9998338 ],\n",
              "       [0.9934341 ],\n",
              "       ...,\n",
              "       [0.9997185 ],\n",
              "       [0.02555481],\n",
              "       [1.        ]], dtype=float32)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    }
  ]
}